"""
数据集导入
以及可以绘制用户想看的数据曲线
依旧使用命令行控制
既可以直接接入，也可以自由控制
"""
from glob import glob
import math

from loguru import logger
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset


class LoadData(Dataset):
    def __init__(
        self,
        data_source: str,
        space: int = 1460,
        class_name: str = "收盘",
        time_step: int = 30,
        out_days: int = 1,
        test_day_split=90,
        train_mode: str = "train",
    ):
        """
        space: 获取最后space天数据进行训练与测试，过分久远的数据对于训练无意义，默认为四年（1460天）
        data_source: 数据源，就是csv的文件路径
        train_mode: 训练模式。可以为：train test
        class_name: 使用哪个数据进行训练：可选：开盘,最高,最低,收盘,成交量,成交额
        time_step: 用多少天预测 默认输入30天
        out_days: 输出预测多少天 默认只预测未来1天
        test_day_split: 获取多少天用于测试，默认是90天。如果是每三十天预测1天，也能预测60天份的。
                        注意：测试集是 第0天到第test_day_split - 1天
        """
        self.train_mode = train_mode
        self.time_step = time_step
        self.out_days = out_days
        self.class_name = class_name
        self.data_source = data_source
        self.space = space
        self.test_day = math.ceil(test_day_split / out_days) * out_days
        self.data = pd.read_csv(self.data_source)
        self.flow_norm, self.flow_data = LoadData.pre_process_data(self.data[self.class_name].astype(np.float64).values[:, np.newaxis])
        self.flow_data =  self.flow_data[-space:]
        # self.use_data = (
        #     self.data[self.class_name][-space:].astype(np.float64).values[:, np.newaxis]
        # )
        self.data_time = self.data["日期"][-space:]  # 获取训练集与测试集的时间
        self.train_data = self.flow_data[self.test_day :]  # 为了准确度，要将今日也纳入训练集
        
        # 测试集内容：
        # 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
        # 假设用五天推断后五天，15天的数据就可以预测两次：
        # 1 2 3 4 5 -> 6 7 8 9 10
        # 6 7 8 9 10 ->  11 12 13 14 15
        self.test_data = self.flow_data[: self.test_day]  # 测试集只是为了测试,用以前的就行
        # print(self.test_data,"哈哈哈")
        self.train_data_time = self.data_time[self.test_day :]
        
        self.test_data_time = self.data_time[: self.test_day]
        # print(self.test_data_time,"嘿嘿")
        # print(self.flow_norm,self.flow_norm[0] -self.flow_norm[1],"吼吼")
        # print(self.test_data)
        if len(self.test_data) < 90:
            raise ValueError("测试数据数量过少，请重新划分数据集")

    def __len__(self):
        if self.train_mode == "train":
            return len(self.train_data) - self.time_step - (self.out_days - 1)
        elif self.train_mode == "test":
            # （天 - 用于预测的时间长度）/输出时间长度
            return int((len(self.test_data) - self.time_step) / self.out_days)
        else:
            raise ValueError(" train mode error")

    def __getitem__(self, index):
        if self.train_mode == "train":
            my_data = self.train_data[index : index + self.time_step]
            my_label = self.train_data[index + self.time_step : index + self.time_step + self.out_days]
        elif self.train_mode == "test":
            my_data = self.test_data[index * self.out_days : index * self.out_days + self.time_step]
            my_label = self.test_data[index * self.out_days + self.time_step : index * self.out_days + self.time_step + self.out_days]
        else:
            raise ValueError("train mode error")
        my_data = LoadData.to_tensor(my_data)
        my_label = LoadData.to_tensor(my_label)
        return {"my_data": my_data, "my_label": my_label}

    # 数据预处理
    @staticmethod
    def pre_process_data(data):
        norm_base = LoadData.normalized_base(data)
        normalized_data = LoadData.normalized_data(data, norm_base[0], norm_base[1])
        return norm_base, normalized_data

    # 生成原始数据中最大值与最小值
    @staticmethod
    def normalized_base(data):
        max_data = np.max(data, keepdims=True)  # keepdims保持维度不变
        min_data = np.min(data, keepdims=True)
        # max_data.shape  --->(1, 1)
        return max_data, min_data

    # 对数据进行标准化
    @staticmethod
    def normalized_data(data, max_data, min_data):
        data_base = max_data - min_data
        normalized_data = (data - min_data) / data_base
        return normalized_data

    @staticmethod
    # 反标准化  在评价指标误差以及画图的使用使用
    def recoverd_data(data, max_data, min_data):
        data_base = max_data - min_data
        recoverd_data = data * data_base + min_data
        return recoverd_data

    @staticmethod
    def to_tensor(data):
        return torch.tensor(data, dtype=torch.float)


if __name__ == "__main__":
    # train_data = LoadData("datasets/600000_浦发银行.csv")
    test_data = LoadData("datasets/600000_浦发银行.csv", train_mode="test")
    # train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=32, shuffle=False)
    # print(len(train_loader))
    # print(len(test_loader))